{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AMIE 实践\n", 
    "by 张文 [wenzhang2015@zju.edu.cn](wenzhang2015@zju.edu.cn) \n",
    “\n”,
    "在这个演示中，我们使用[AMIE](https://hal-imt.archives-ouvertes.fr/hal-01699866/document)对知识图谱进行规则挖掘。\n",
    "\n",
    "希望在这个demo中帮助大家了解知识图谱规则挖掘的作用原理和机制。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AMIE原理回顾\n",
    "### 规则\n",
    "知识图谱中的规则可以用以下形式表示：\n",
    "$$ B_1 \\land B_2 \\land ... \\land B_n \\Rightarrow H $$\n",
    "其中$B_1 \\land B_2 \\land ... \\land B_n$表示规则的body部分，有n个原子(atom)组成，$H$表示规则的head部分，由一个原子组成，每个原子A可以表示为$A = r(x,y)$的形式，其中$r$表示为原子包含的关系，$x,y$为变量。\n",
    "\n",
    "本demo中AMIE学习的规则为所有规则中的一个子集，即闭环的联通的规则，也可以叫做路径规则。可表示为：\n",
    "$$ r_1(x, z_1) \\land r_2(z_1, z_2) \\land ... \\land r_n(z_{n-1}, y) \\Rightarrow r(x,y) $$\n",
    "我们可以将其简化表示为$\\vec{B}\\Rightarrow r_0(x,y)$。\n",
    "\n",
    "如果规则中的所有变量替换为具体的实体并保证每个实例化后的atom都存在图谱中，这样规则实例化后的结果称为规则的一个grounding。\n",
    "\n",
    "### 规则的几个统计指标\n",
    "#### Support\n",
    "对于规则$rule$, support为在知识图谱中规则的grounding的个数, 可表示为\n",
    "$$ supp(\\vec{B}\\Rightarrow r_0(x,y)) = \\#(x,y): \\exists z_1, ..., z_n : \\vec{B}\\Rightarrow r_0(x,y)) = \\#(x,y) $$\n",
    "#### HC (head coverage)\n",
    "头覆盖度是另一个评价规则的指标，定义如下：\n",
    "$$ HC(\\vec{B}\\Rightarrow r_0(x,y)) = \\frac{suppo(\\vec{B}\\Rightarrow r_0(x,y))}{size(r)} $$\n",
    "其中$size(r)$表示图谱中关系存在的关于$r$的三元组数量，即$(x, r, y)$的数量。\n",
    "#### Confidence\n",
    "置信度的定义如下：\n",
    "$$ conf(\\vec{B}\\Rightarrow r_0(x,y)) = \\frac{suppo(\\vec{B}\\Rightarrow r_0(x,y))}{\\#(x,y): \\exists z_1, ..., z_n : \\vec{B}} $$\n",
    "#### PCA Confidence\n",
    "PCA置信度全称为 partial closed world assumption confidence,即局部封闭世界假设。在封闭世界假设(closed world assumption)中，图谱中不存在的三元组都被视为错误的，而在局部封闭世界假设中假设不存在知识图谱中的三元组不一定是错误的，有的缺失是图谱本身的不完整性造成的。PCA confidence的计算方式如下：\n",
    "$$ conf_{pca}(\\vec{B}\\Rightarrow r_0(x,y)) = \\frac{suppo(\\vec{B}\\Rightarrow r_0(x,y))}{\\#(x,y): \\exists z_1, ..., z_n : \\vec{B} \\land r(x, y^\\prime)} $$\n",
    "与confidence的定义不同，PCA confidence只将满足规则body部分且存在三元组$(x, r, y^\\prime)$的实例才计算入分母，由于考虑了知识图谱本身的不完整性，通常PCA confidence会比confidence更准确。\n",
    "\n",
    "## AMIE算法\n",
    "AMIE是一个基于搜索遍历和剪枝的规则挖掘算法，此类算法的重点在于剪枝策略。\n",
    "下面我们通过代码进行解释，首先导入需要的依赖包，本demo采用了多进程的挖掘方式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocessing import Queue, Process\n",
    "from collections import defaultdict\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import time\n",
    "from timeit import default_timer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "AMIE中的超参包含了三个，一个是最小的头覆盖度minHC，规则的最大长度maxLen, 以及最小置信度minConf，这里的用的是PCA confidence，为了方便起见，后文如果不做特殊说明，confidence都代指PCA confidence。这三个超参是为了规则搜索空间的剪枝服务的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = default_timer()\n",
    "data_dir = './data/WN18RR/'\n",
    "minHC = 0.001\n",
    "maxLen = 3\n",
    "minConf = 0.01\n",
    "epsilon = 1e-5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义两个队列，一个是用于存储候选规则的队列ruleQ，一个是用于存储挖掘出的规则的队列outputQ。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ruleQ = Queue()\n",
    "outputQ = Queue()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义一些在剪枝过程中需要使用的字典，都是关于图谱中的信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "head = defaultdict(set)\n",
    "tail = defaultdict(set)\n",
    "hr2t = defaultdict(set)\n",
    "tr2h = defaultdict(set)\n",
    "ht2r = defaultdict(set)\n",
    "r2ht = defaultdict(set)\n",
    "head2num = defaultdict(int)\n",
    "tail2num = defaultdict(int)\n",
    "asHeadOfRelation = defaultdict(set)\n",
    "asTailOfRelation = defaultdict(set)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先读取知识图谱中train.txt的信息，train.txt中每行为一个三元组，顺序为h r t并用tab符号隔开\n",
    "\n",
    "在读取数据的过程中我们自动添加了关系r的逆关系r_inv, 并为每个三元组(h,r,t)构建了一个逆三元组(t, r_inv, h)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#ENT:40943 \t#REL:22\n"
     ]
    }
   ],
   "source": [
    "def read2id(dir, REL=None):\n",
    "    item2id = dict()\n",
    "    id2item = dict()\n",
    "    items = []\n",
    "    with open(dir, 'r') as f:\n",
    "        for l in f.readlines():\n",
    "            ll = l.strip().split('\\t')\n",
    "            assert len(ll) == 2\n",
    "            item = ll[0]\n",
    "            id = ll[1]\n",
    "            item2id[item] = int(id)\n",
    "            id2item[int(id)] = item\n",
    "            items.append(int(id))\n",
    "            if REL is not None:\n",
    "                item2id[item+'_inv'] = int(id)+REL\n",
    "                id2item[int(id)+REL] = item+'_inv'\n",
    "                items.append(int(id)+REL)\n",
    "    num_item = len(item2id)\n",
    "    return item2id, id2item, num_item, items\n",
    "\n",
    "def record_triple(h,r,t):\n",
    "    head[r].add(h)\n",
    "    tail[r].add(t)\n",
    "    hr2t[(h,r)].add(t)\n",
    "    tr2h[(t,r)].add(h)\n",
    "    r2ht[r].add((h,t))\n",
    "    ht2r[(h,t)].add(r)\n",
    "    head2num[(r,h)] += 1\n",
    "\n",
    "trp_dir = data_dir + 'train.txt'\n",
    "ent2id, id2ent, num_ent, ents = read2id(data_dir + 'entity2id.txt')\n",
    "rel2id, id2rel, num_rel, rels = read2id(data_dir + 'relation2id.txt', REL=11)\n",
    "print('#ENT:%d \t#REL:%d'%(num_ent, num_rel))\n",
    "\n",
    "line_num = 0\n",
    "with open(trp_dir) as f:\n",
    "    for l in f.readlines():\n",
    "        line_num += 1\n",
    "        h,r,t = l.strip().split('\\t')\n",
    "        h = ent2id[h]\n",
    "        r = rel2id[r]\n",
    "        t = ent2id[t]\n",
    "        r_inv = r+num_rel\n",
    "        record_triple(h,r,t)\n",
    "        record_triple(t, r_inv, h)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "初始化规则队列，初始的规则为只包含head的规则，在程序中我们用列表\\[$r, r_1, ..., r_n, 1$\\]表示闭合规则$r_1(x, z_1) \\land r_2(z_1, z_2) \\land ... \\land r_n(z_{n-1}, y) \\Rightarrow r(x,y)$， 用\\[$r, r_1, ..., r_n, 0$\\]非闭合规则$ r_1(x, z_1) \\land r_2(z_1, z_2) \\land ... \\land r_n(z_{n-1}, y^\\prime) \\Rightarrow r(x,y)$. 注意别表的最后一位数字用于表示规则是否闭合。初始化的规则均为非闭合规则。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "for rel in id2rel.keys():\n",
    "    ruleQ.put([rel, 0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先定义一些规则的学习指标函数，包括判断规则是否闭合，规则的长度，规则的support，PCA confidence以及HC指标。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def closed(rule):\n",
    "    return rule[-1] == 1\n",
    "\n",
    "def length(rule):\n",
    "    return len(rule)-1\n",
    "\n",
    "def support(rule, PCA=False):\n",
    "    lenth = length(rule)\n",
    "    close = closed(rule)\n",
    "    if lenth == 2:\n",
    "        r0 = rule[0]\n",
    "        r1 = rule[1]\n",
    "        if close:\n",
    "            sup = len(r2ht[r0].intersection(r2ht[r1]))\n",
    "        else:\n",
    "            intersection_e = head[r0].intersection(head[r1])\n",
    "            sup = 0\n",
    "            for h in intersection_e:\n",
    "                num1 = head2num[(r0,h)]\n",
    "                num2 = head2num[(r1,h)]\n",
    "                sup += num1*num2\n",
    "        pca_num = np.sum(np.asarray([(len(hr2t[(h, r0)]) != 0) * head2num[(r1, h)] for h in head[r1]], dtype=np.int64))\n",
    "    elif close and lenth ==3:\n",
    "        r0 = rule[0]\n",
    "        r1 = rule[1]\n",
    "        r2 = rule[2]\n",
    "        sup = sum([len(hr2t[(h, r1)].intersection(tr2h[(t,r2)])) for h,t in r2ht[r0]])\n",
    "        pca_num = 0\n",
    "        for h,t in r2ht[r1]:\n",
    "            pca_num += np.sum(np.asarray([len(hr2t[(h, r0)])!=0 for tt in hr2t[(t,r2)]], dtype=np.int64))\n",
    "    else:\n",
    "        print(rule, length(rule))\n",
    "        raise NotImplementedError('rule do not calculate support')\n",
    "    if PCA:\n",
    "        return sup, pca_num\n",
    "    else:\n",
    "        return sup\n",
    "\n",
    "def pca_confidence(rule):\n",
    "    sup, pca_num = support(rule, PCA=True)\n",
    "    return (sup/pca_num)\n",
    "\n",
    "def HC(rule):\n",
    "    sup = support(rule)\n",
    "    headr = rule[0]\n",
    "    return sup/(len(r2ht[headr])+epsilon)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "conf_increase函数用于判断一条规则的confidence是否在其父规则的基础上有增加"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conf_increase(rule):\n",
    "    if length(rule) == 2:\n",
    "        return True\n",
    "    parent = deepcopy(rule)\n",
    "    del parent[-2]\n",
    "    if closed(rule): parent[-1] = 0\n",
    "    parent_conf = pca_confidence(parent)\n",
    "    child_conf = pca_confidence(rule)\n",
    "    return child_conf>parent_conf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "acceptedForOutput函数用于判断一条规则是否满足输出规则的条件，\n",
    "\n",
    "以下规则不满足输出条件：（1）规则是非闭合的，（2规则的confidence小于设置的minConf，（3）规则的confidence在父规则的基础上没有增加"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def acceptedForOutput(rule):\n",
    "    if not closed(rule):\n",
    "        return False\n",
    "    elif pca_confidence(rule) < minConf:\n",
    "        return False\n",
    "    elif not conf_increase(rule):\n",
    "        return False\n",
    "    else:\n",
    "        return True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "addingDanglingAtom 和 addingClosingAtom 两个函数是规则refine阶段的两个操作\n",
    "\n",
    "addingDanglingAtom在原有规则基础上增加一个atom，这个atom $r_i(x,y)$，其中$x$是规则中已存在的变量，$y$是规则中不存在的变量。\n",
    "\n",
    "addingClosingAtom在原有规则基础上增加一个atom，这个atom $r_i(x,y)$，其中$x,y$都是规则中已存在的变量，并且添加后是的规则闭合。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def addingDanglingAtom(rule):\n",
    "    assert rule[-1] != 1\n",
    "    if length(rule) == 2:\n",
    "        return []\n",
    "    for rel in rels:\n",
    "        rule_new = deepcopy(rule)\n",
    "        rule_new.insert(-1, rel)\n",
    "        if HC(rule_new) >= minHC:\n",
    "            yield rule_new\n",
    "def addingClosingAtom(rule):\n",
    "    assert rule[-1] != 1\n",
    "    for i, rel in enumerate(rels):\n",
    "        rule_new = deepcopy(rule)\n",
    "        rule_new.insert(-2, i)\n",
    "        rule_new[-1] = 1\n",
    "        if HC(rule_new) >= minHC:\n",
    "            yield rule_new"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义多进程的任务task。\n",
    "\n",
    "这个任务的定义主要包含以下操作：\n",
    "\n",
    "(1) 从规则队列inputQ中获取一个候选规则。\n",
    "\n",
    "(2) 通过acceptedForOutput函数判断这个规则是够满足输出的条件，如果满足，将其加入到输出规则队列。\n",
    "\n",
    "(3) 如果规则还未达到maxLen并且为非闭合的，则对规则进行refine，包括addingDanglingAtom，和addingClosingAtom两种操作，并将所有refine过程中新生成的规则填加到规则队列中。\n",
    "\n",
    "(4) 如果规则队列已经为空，结束当前规则挖掘进程。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def task(inputQ, outputQ):\n",
    "    while True:\n",
    "        if not inputQ.empty():\n",
    "            rule = inputQ.get()\n",
    "            if length(rule)!=1 and acceptedForOutput(rule): outputQ.put(rule)\n",
    "            if length(rule)<maxLen and not closed(rule):\n",
    "                for rul in addingDanglingAtom(rule):\n",
    "                    inputQ.put(rul)\n",
    "                for rul in addingClosingAtom(rule):\n",
    "                    inputQ.put(rul)\n",
    "        else:\n",
    "            print('task done')\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义多个规则挖掘进程，以task函数为目标，以ruleQ和outputQ为变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "process 0 start\n",
      "process 1 start\n",
      "process 2 start\n",
      "process 3 start\n",
      "process 4 start\n",
      "process 5 start\n",
      "process 6 start\n",
      "process 7 start\n",
      "process 8 start\n",
      "process 9 start\n",
      "task done\n",
      "task done\n",
      "task done\n",
      "task done\n",
      "task done\n",
      "task done\n",
      "task done\n",
      "task done\n",
      "task done\n",
      "task done\n"
     ]
    }
   ],
   "source": [
    "num_worker = 10\n",
    "for i in range(num_worker):\n",
    "    p = Process(target=task, args=(ruleQ, outputQ))\n",
    "    print('process ' + str(i) + ' start')\n",
    "    p.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "规则挖掘完成对挖掘到的规则进行解析："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rule mining finished\n",
      "RULE \tSUPP\tHC\tCONF_pca\n",
      "_similar_to <== _similar_to \t80\t1.0000\t1.0000\n",
      "_member_of_domain_usage <== _member_of_domain_usage \t629\t1.0000\t1.0000\n",
      "_instance_hypernym <== _instance_hypernym \t2921\t1.0000\t1.0000\n",
      "_derivationally_related_form <== _derivationally_related_form \t29715\t1.0000\t1.0000\n",
      "_has_part <== _has_part \t4816\t1.0000\t1.0000\n",
      "_synset_domain_topic_of <== _synset_domain_topic_of \t3116\t1.0000\t1.0000\n",
      "_member_meronym <== _member_meronym \t7402\t1.0000\t1.0000\n",
      "_also_see <== _also_see \t1299\t1.0000\t1.0000\n",
      "_member_of_domain_region <== _member_of_domain_region \t923\t1.0000\t1.0000\n",
      "_verb_group <== _verb_group \t1138\t1.0000\t1.0000\n",
      "_synset_domain_topic_of <== _derivationally_related_form \t24\t0.0077\t0.0135\n",
      "_member_meronym <== _derivationally_related_form \t22\t0.0030\t0.1549\n",
      "_hypernym <== _hypernym \t34796\t1.0000\t1.0000\n",
      "_also_see <== _hypernym \t6\t0.0046\t0.0316\n",
      "_verb_group <== _hypernym \t17\t0.0149\t0.0208\n",
      "15 rules are learned\n",
      "with in 53.999620\n"
     ]
    }
   ],
   "source": [
    "def decodeRules():\n",
    "    num_rules = 0\n",
    "    print('RULE \tSUPP\tHC\tCONF_pca')\n",
    "    while not outputQ.empty():\n",
    "        rul = outputQ.get()\n",
    "        num_rules += 1\n",
    "        assert closed(rul)\n",
    "        rul_rs = deepcopy(rul)\n",
    "        del rul_rs[-1]\n",
    "        rule_str = ''\n",
    "        for i,r in enumerate(rul[:-1]):\n",
    "            r = id2rel[r]\n",
    "            rule_str += r\n",
    "            if i == 0:\n",
    "                rule_str += ' <== '\n",
    "            if i < length(rul)-1 and i != 0:\n",
    "                rule_str += ' & '\n",
    "        print('%s \t%.d\t%.4f\t%.4f'%(rule_str, support(rul),HC(rul), pca_confidence(rul)))\n",
    "    print('%d rules are learned'%(num_rules))\n",
    "\n",
    "while True:\n",
    "    if ruleQ.empty():\n",
    "        print('Rule mining finished')\n",
    "        decodeRules()\n",
    "        end = default_timer()\n",
    "        print('with in %3f'%(end-start))\n",
    "        break\n",
    "    else:\n",
    "        time.sleep(3)\n",
    "        continue"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 后记\n",
    "\n",
    "由于WN18RR是一个关系比较少的数据集，所以挖掘出的规则数量较少，同学们有兴趣可以自己采用其他数据集试一试，同时也可以探索一些不同的minHC和minConf的设置对算法运行的影响。\n",
    "\n",
    "AMIE+采用了数据库查询语言SQL和SPARQL搭配数据库系统进行了算法增速，本demo中未涉及。同时demo中还省去了一些算法细节，有兴趣的同学可以自己阅读原文并查看其源码。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
